import os

import torch

# Unit constants copied from ``ase.units`` (ASE's eV / Angstrom / amu unit system).
kB = 8.617330337217213e-05  # Boltzmann constant, eV/K
fs = 0.09822694788464063  # one femtosecond expressed in ASE time units
Hartree = 27.211386024367243  # Hartree energy, eV
Bohr = 0.5291772105638411  # Bohr radius, Angstrom
half_Hartree = Hartree / 2  # half a Hartree, eV
Bohr_inv = 1.0 / Bohr  # inverse Bohr radius, 1/Angstrom


def get_masses(device="cpu"):
    """Return standard atomic masses as a tensor indexed by atomic number.

    Values follow `ase.data.atomic_masses` for Z = 1..118 (atomic mass
    units). Entry 0 is a 0.0 placeholder, so ``masses[Z]`` is the mass of
    element Z.

    Args:
        device: Device on which the returned tensor is created.

    Returns:
        1-D tensor of length 119 in torch's default floating dtype.
    """
    masses_by_z = (
        0.0,  # Z = 0 placeholder
        # Period 1
        1.008, 4.002602,  # H-He
        # Period 2
        6.94, 9.0121831, 10.81, 12.011,  # Li-C
        14.007, 15.999, 18.99840316, 20.1797,  # N-Ne
        # Period 3
        22.98976928, 24.305, 26.9815385, 28.085,  # Na-Si
        30.973762, 32.06, 35.45, 39.948,  # P-Ar
        # Period 4
        39.0983, 40.078,  # K-Ca
        44.955908, 47.867, 50.9415, 51.9961, 54.938044,  # Sc-Mn
        55.845, 58.933194, 58.6934, 63.546, 65.38,  # Fe-Zn
        69.723, 72.63, 74.921595, 78.971, 79.904, 83.798,  # Ga-Kr
        # Period 5
        85.4678, 87.62,  # Rb-Sr
        88.90584, 91.224, 92.90637, 95.95, 97.90721,  # Y-Tc
        101.07, 102.9055, 106.42, 107.8682, 112.414,  # Ru-Cd
        114.818, 118.71, 121.76, 127.6, 126.90447, 131.293,  # In-Xe
        # Period 6
        132.90545196, 137.327,  # Cs-Ba
        138.90547, 140.116, 140.90766, 144.242, 144.91276,  # La-Pm
        150.36, 151.964, 157.25, 158.92535, 162.5,  # Sm-Dy
        164.93033, 167.259, 168.93422, 173.054, 174.9668,  # Ho-Lu
        178.49, 180.94788, 183.84, 186.207, 190.23,  # Hf-Os
        192.217, 195.084, 196.966569, 200.592,  # Ir-Hg
        204.38, 207.2, 208.9804, 208.98243, 209.98715, 222.01758,  # Tl-Rn
        # Period 7
        223.01974, 226.02541,  # Fr-Ra
        227.02775, 232.0377, 231.03588, 238.02891, 237.04817,  # Ac-Np
        244.06421, 243.06138, 247.07035, 247.07031, 251.07959,  # Pu-Cf
        252.083, 257.09511, 258.09843, 259.101, 262.11,  # Es-Lr
        267.122, 268.126, 271.134, 270.133, 269.1338,  # Rf-Hs
        278.156, 281.165, 281.166, 285.177,  # Mt-Cn
        286.182, 289.19, 289.194, 293.204, 293.208, 294.214,  # Nh-Og
    )
    return torch.tensor(masses_by_z, device=device)


def get_gfn1_rep(device="cpu"):
    """Return the GFN1 repulsion parameters as two per-element tensors.

    Both tables are indexed by atomic number for Z = 0..86 (H through Rn),
    so each returned tensor has length 87. Entry 0 is a placeholder: alpha
    is a tiny positive value there (not 0) and Zeff is 0.

    The tabulated values are rescaled before being returned:
        repa = sqrt(alpha) * Bohr_inv**0.75
        repb = Zeff * sqrt(0.5 * Hartree * Bohr)
    NOTE(review): these factors look like a conversion from atomic units to
    Angstrom / eV, with the 0.5 presumably accounting for pair double
    counting in the caller -- confirm against the energy expression.

    Args:
        device: Device the returned tensors are moved to.

    Returns:
        Tuple ``(repa, repb)`` of 1-D tensors in torch's default float dtype.
    """
    alpha = torch.tensor([
        0.000001,  # Z = 0 placeholder
        2.209700, 1.382907,  # H-He
        0.671797, 0.865377, 1.093544, 1.281954,  # Li-C
        1.727773, 2.004253, 2.507078, 3.038727,  # N-Ne
        0.704472, 0.862629, 0.929219, 0.948165,  # Na-Si
        1.067197, 1.200803, 1.404155, 1.323756,  # P-Ar
        0.581529, 0.665588,  # K-Ca
        0.841357, 0.828638, 1.061627, 0.997051, 1.019783,  # Sc-Mn
        1.137174, 1.188538, 1.399197, 1.199230, 1.145056,  # Fe-Zn
        1.047536, 1.129480, 1.233641, 1.270088, 1.153580, 1.335287,  # Ga-Kr
        0.554032, 0.657904,  # Rb-Sr
        0.760144, 0.739520, 0.895357, 0.944064, 1.028240,  # Y-Tc
        1.066144, 1.131380, 1.206869, 1.058886, 1.026434,  # Ru-Cd
        0.898148, 1.008192, 0.982673, 0.973410, 0.949181, 1.074785,  # In-Xe
        0.579919, 0.606485,  # Cs-Ba
        1.311200, 0.839861, 0.847281, 0.854701, 0.862121,  # La-Pm
        0.869541, 0.876961, 0.884381, 0.891801, 0.899221,  # Sm-Dy
        0.906641, 0.914061, 0.921481, 0.928901, 0.936321,  # Ho-Lu
        0.853744, 0.971873, 0.992643, 1.132106, 1.118216,  # Hf-Os
        1.245003, 1.304590, 1.293034, 1.181865,  # Ir-Hg
        0.976397, 0.988859, 1.047194, 1.013118, 0.964652, 0.998641,  # Tl-Rn
    ])
    z_eff = torch.tensor([
        0.000000,  # Z = 0 placeholder
        1.116244, 0.440231,  # H-He
        2.747587, 4.076830, 4.458376, 4.428763,  # Li-C
        5.498808, 5.171786, 6.931741, 9.102523,  # N-Ne
        10.591259, 15.238107, 16.283595, 16.898359,  # Na-Si
        15.249559, 15.100323, 17.000000, 17.153132,  # P-Ar
        20.831436, 19.840212,  # K-Ca
        18.676202, 17.084130, 22.352532, 22.873486, 24.160655,  # Sc-Mn
        25.983149, 27.169215, 23.396999, 29.000000, 31.185765,  # Fe-Zn
        33.128619, 35.493164, 36.125762, 32.148852, 35.000000, 36.000000,  # Ga-Kr
        39.653032, 38.924904,  # Rb-Sr
        39.000000, 36.521516, 40.803132, 41.939347, 43.000000,  # Y-Tc
        44.492732, 45.241537, 42.105527, 43.201446, 49.016827,  # Ru-Cd
        51.718417, 54.503455, 50.757213, 49.215262, 53.000000, 52.500985,  # In-Xe
        65.029838, 46.532974,  # Cs-Ba
        48.337542, 30.638143, 34.130718, 37.623294, 41.115870,  # La-Pm
        44.608445, 48.101021, 51.593596, 55.086172, 58.578748,  # Sm-Dy
        62.071323, 65.563899, 69.056474, 72.549050, 76.041625,  # Ho-Lu
        55.222897, 63.743065, 74.000000, 75.000000, 76.000000,  # Hf-Os
        77.000000, 78.000000, 79.000000, 80.000000,  # Ir-Hg
        81.000000, 79.578302, 83.000000, 84.000000, 85.000000, 86.000000,  # Tl-Rn
    ])

    # Scalar rescaling factors (plain Python floats).
    alpha_scale = Bohr_inv**0.75
    z_eff_scale = (0.5 * Hartree * Bohr) ** 0.5

    repa = (alpha ** 0.5) * alpha_scale
    repb = z_eff * z_eff_scale
    return repa.to(device), repb.to(device)


def get_r4r2(device="cpu"):
    """Return the per-element r4/r2 scaling factors for DFT-D3.

    The tabulated <r4>/<r2> ratios come from
    https://github.com/dftd4/dftd4/blob/main/src/dftd4/data/r4r2.f90 and are
    turned into factors via ``sqrt(0.5 * ratio[Z] * sqrt(Z))``.

    The result is indexed by atomic number Z = 0..118. Entry 0 is a 0.0
    placeholder. Elements without a tabulated ratio (Am-Rg) also get 0.0.

    Args:
        device: Device the returned tensor is moved to.

    Returns:
        1-D tensor of length 119 in torch's default floating dtype.
    """
    r4_over_r2 = (
        [0.0]  # Z = 0 placeholder
        + [8.0589, 3.4698]  # H-He
        + [29.0974, 14.8517, 11.8799, 7.8715]  # Li-C
        + [5.5588, 4.7566, 3.8025, 3.1036]  # N-Ne
        + [26.1552, 17.2304, 17.7210, 12.7442]  # Na-Si
        + [9.5361, 8.1652, 6.7463, 5.6004]  # P-Ar
        + [29.2012, 22.3934]  # K-Ca
        + [19.0598, 16.8590, 15.4023, 12.5589, 13.4788]  # Sc-Mn
        + [12.2309, 11.2809, 10.5569, 10.1428, 9.4907]  # Fe-Zn
        + [13.4606, 10.8544, 8.9386, 8.1350, 7.1251, 6.1971]  # Ga-Kr
        + [30.0162, 24.4103]  # Rb-Sr
        + [20.3537, 17.4780, 13.5528, 11.8451, 11.0355]  # Y-Tc
        + [10.1997, 9.5414, 9.0061, 8.6417, 8.9975]  # Ru-Cd
        + [14.0834, 11.8333, 10.0179, 9.3844, 8.4110, 7.5152]  # In-Xe
        + [32.7622, 27.5708]  # Cs-Ba
        + [23.1671, 21.6003, 20.9615, 20.4562, 20.1010, 19.7475, 19.4828]  # La-Eu
        + [15.6013, 19.2362, 17.4717, 17.8321, 17.4237, 17.1954, 17.1631]  # Gd-Yb
        + [14.5716, 15.8758, 13.8989, 12.4834, 11.4421]  # Lu-Re
        + [10.2671, 8.3549, 7.8496, 7.3278, 7.4820]  # Os-Hg
        + [13.5124, 11.6554, 10.0959, 9.7340, 8.8584, 8.0125]  # Tl-Rn
        + [29.8135, 26.3157]  # Fr-Ra
        + [19.1885, 15.8542, 16.1305, 15.6161, 15.1226, 16.1576]  # Ac-Pu
        + [0.0] * 17  # Am-Rg: no tabulated ratio
        + [5.4929]  # Cn
        + [6.7286, 6.5144, 10.9169, 10.3600, 9.4723, 8.6641]  # Nh-Og
    )

    ratios = torch.tensor(r4_over_r2)
    atomic_numbers = torch.arange(len(r4_over_r2))
    factors = torch.sqrt(0.5 * ratios * torch.sqrt(atomic_numbers))
    return factors.to(device)


def get_dftd3_param(device="cpu"):
    """Load the collection of DFT-D3 parameters from ``dftd3_data.pt``.

    The file is looked up in the same directory as this module and read
    with ``torch.load(..., weights_only=True)``.

    Args:
        device: Passed to ``torch.load`` as ``map_location``, so the loaded
            tensors are placed on this device.

    Returns:
        dict: The loaded parameters. It is guaranteed to contain at least the
        keys ``"c6ab"``, ``"r4r2"``, ``"rcov"`` and ``"cnmax"``.

    Raises:
        FileNotFoundError: If ``dftd3_data.pt`` does not exist.
        ValueError: If the file does not hold a dict, or the dict lacks any
            of the required keys.
    """
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "dftd3_data.pt")
    if not os.path.exists(filename):
        raise FileNotFoundError(f"dftd3_data.pt not found in {dirname}.")
    param = torch.load(filename, map_location=device, weights_only=True)

    # Validate with explicit exceptions rather than `assert`: assertions are
    # stripped under `python -O`, which would let a malformed file through.
    if not isinstance(param, dict):
        raise ValueError(
            f"{filename} must contain a dict of parameters, "
            f"got {type(param).__name__}."
        )
    required_keys = ("c6ab", "r4r2", "rcov", "cnmax")
    missing = [key for key in required_keys if key not in param]
    if missing:
        raise ValueError(
            f"{filename} is missing required DFT-D3 entries: {', '.join(missing)}."
        )
    return param
